{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-21T17:59:03.644207Z",
     "start_time": "2020-05-21T17:59:03.635658Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style> td {font-size: 18px} </style>\n",
       "<style> tr {font-size: 18px} </style>\n",
       "<style> li {font-size: 18px} </style>\n",
       "<style> img {height: 50%} </style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%HTML\n",
    "<style> td {font-size: 18px} </style>\n",
    "<style> tr {font-size: 18px} </style>\n",
    "<style> li {font-size: 18px} </style>\n",
    "<style> img {height: 50%} </style>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-21T17:59:04.937164Z",
     "start_time": "2020-05-21T17:59:04.897061Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n"
     ]
    }
   ],
   "source": [
    "from utils import print_data_stats, load_data, flat_accuracy, format_time, subset_data\n",
    "\n",
    "import pandas as pd\n",
    "pd.set_option('precision', 3)\n",
    "\n",
    "import random\n",
    "random.seed(11)\n",
    "\n",
    "import torch\n",
    "device = torch.device('cuda') if torch. cuda.is_available() else 'cpu'\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "LANGS = [\"ar\",\"en\",\"es\",\"ru\",\"zh\"]\n",
    "LANGS_MAPPING = {\"en\":\"english\",\"es\":\"spanish\",\"ru\":\"russian\",\"ar\":\"arabic\",\"zh\":\"chinese\"}\n",
    "\n",
    "data = load_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# mBERT\n",
    "![BERT](https://yashuseth.files.wordpress.com/2019/06/fig1-1.png)\n",
    "<!-- ![title](https://miro.medium.com/max/1400/0*lBYVNRe1esIXn1qE.png) -->\n",
    "**BERT**: Bidirectional Encoder Representations for Transformers  \n",
    "**mBERT**: BERT pre-trained from monolingual corpora in 104 languages\n",
    "\n",
    "- Commonly used for cross-lingual transfer these days\n",
    "- [A Primer in BERTology: What we know about how BERT works](https://arxiv.org/abs/2002.12327)\n",
    "- [How multilingual is Multilingual BERT?](https://arxiv.org/abs/1906.01502)\n",
    "- [The Illustrated Transformers](http://jalammar.github.io/illustrated-transformer/)\n",
    "- [Huggingface multilingual models intro](https://huggingface.co/transformers/v2.2.0/multilingual.html)\n",
    "- Codes below are substantially borrowed from [this blog post](https://mccormickml.com/2019/07/22/BERT-fine-tuning/) by Chris McCormick and Nick Ryan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-21T18:01:21.775783Z",
     "start_time": "2020-05-21T18:01:06.824535Z"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import BertForSequenceClassification, BertTokenizer, AdamW\n",
    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-uncased\", cache_dir=\"../transformer-models/\", do_lower_case=True)\n",
    "model = BertForSequenceClassification.from_pretrained(\"bert-base-multilingual-uncased\", num_labels = 2, cache_dir=\"../transformer-models/\", output_attentions = False, output_hidden_states = False).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-21T16:35:47.217301Z",
     "start_time": "2020-05-21T16:35:47.205921Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The BERT model has 201 different named parameters.\n",
      "\n",
      "==== Embedding Layer ====\n",
      "\n",
      "bert.embeddings.word_embeddings.weight                  (105879, 768)\n",
      "bert.embeddings.position_embeddings.weight                (512, 768)\n",
      "bert.embeddings.token_type_embeddings.weight                (2, 768)\n",
      "bert.embeddings.LayerNorm.weight                              (768,)\n",
      "bert.embeddings.LayerNorm.bias                                (768,)\n",
      "\n",
      "==== First Transformer ====\n",
      "\n",
      "bert.encoder.layer.0.attention.self.query.weight          (768, 768)\n",
      "bert.encoder.layer.0.attention.self.query.bias                (768,)\n",
      "bert.encoder.layer.0.attention.self.key.weight            (768, 768)\n",
      "bert.encoder.layer.0.attention.self.key.bias                  (768,)\n",
      "bert.encoder.layer.0.attention.self.value.weight          (768, 768)\n",
      "bert.encoder.layer.0.attention.self.value.bias                (768,)\n",
      "bert.encoder.layer.0.attention.output.dense.weight        (768, 768)\n",
      "bert.encoder.layer.0.attention.output.dense.bias              (768,)\n",
      "bert.encoder.layer.0.attention.output.LayerNorm.weight        (768,)\n",
      "bert.encoder.layer.0.attention.output.LayerNorm.bias          (768,)\n",
      "bert.encoder.layer.0.intermediate.dense.weight           (3072, 768)\n",
      "bert.encoder.layer.0.intermediate.dense.bias                 (3072,)\n",
      "bert.encoder.layer.0.output.dense.weight                 (768, 3072)\n",
      "bert.encoder.layer.0.output.dense.bias                        (768,)\n",
      "bert.encoder.layer.0.output.LayerNorm.weight                  (768,)\n",
      "bert.encoder.layer.0.output.LayerNorm.bias                    (768,)\n",
      "\n",
      "==== Output Layer ====\n",
      "\n",
      "bert.pooler.dense.weight                                  (768, 768)\n",
      "bert.pooler.dense.bias                                        (768,)\n",
      "classifier.weight                                           (2, 768)\n",
      "classifier.bias                                                 (2,)\n"
     ]
    }
   ],
   "source": [
    "# Get all of the model's parameters as a list of tuples.\n",
    "params = list(model.named_parameters())\n",
    "\n",
    "print('The BERT model has {:} different named parameters.\\n'.format(len(params)))\n",
    "\n",
    "print('==== Embedding Layer ====\\n')\n",
    "\n",
    "for p in params[0:5]:\n",
    "    print(\"{:<55} {:>12}\".format(p[0], str(tuple(p[1].size()))))\n",
    "\n",
    "print('\\n==== First Transformer ====\\n')\n",
    "\n",
    "for p in params[5:21]:\n",
    "    print(\"{:<55} {:>12}\".format(p[0], str(tuple(p[1].size()))))\n",
    "\n",
    "print('\\n==== Output Layer ====\\n')\n",
    "\n",
    "for p in params[-4:]:\n",
    "    print(\"{:<55} {:>12}\".format(p[0], str(tuple(p[1].size()))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-21T16:44:06.458142Z",
     "start_time": "2020-05-21T16:44:06.441063Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " Original:  مكااااان الفندق بعيد قليلا لكنه كذلك مقبول ومعروف عند التاكسيات وممكن الوصول له عبر الجي بي اس\n",
      "Tokenized:  ['م', '##كا', '##ا', '##ا', '##ا', '##ان', 'الف', '##ند', '##ق', 'ب', '##عيد', 'ق', '##ليل', '##ا', 'لكنه', 'كذلك', 'م', '##قب', '##ول', 'ومع', '##روف', 'عند', 'ال', '##تا', '##كس', '##يات', 'و', '##مم', '##كن', 'الوصول', 'له', 'عبر', 'ال', '##جي', 'بي', 'اس'] \n",
      "\n",
      " Original:  I was looking for banana tempura for dessert and they dont have.\n",
      "Tokenized:  ['i', 'was', 'looking', 'for', 'banana', 'temp', '##ura', 'for', 'desse', '##rt', 'and', 'they', 'dont', 'have', '.'] \n",
      "\n",
      " Original:  El trato muy bueno, muy amables, el arroz exquisito y a mitad de precio, lo recomiendo.\n",
      "Tokenized:  ['el', 'trato', 'muy', 'bueno', ',', 'muy', 'ama', '##bles', ',', 'el', 'arroz', 'ex', '##quis', '##ito', 'y', 'a', 'mitad', 'de', 'precio', ',', 'lo', 'rec', '##omi', '##endo', '.'] \n",
      "\n",
      " Original:  Вот идем мы с мамой после театра довольные и счастливые, и на свою беду решаю отвести ее в длинный хвост.\n",
      "Tokenized:  ['вот', 'иде', '##м', 'мы', 'с', 'ма', '##мои', 'после', 'театра', 'до', '##во', '##льные', 'и', 'с', '##час', '##т', '##лив', '##ые', ',', 'и', 'на', 'свою', 'б', '##ед', '##у', 'р', '##еш', '##аю', 'отв', '##ести', 'ее', 'в', 'д', '##лин', '##ныи', 'х', '##во', '##ст', '.'] \n",
      "\n",
      " Original:  各种APP按钮的排序也是混乱的，\n",
      "Tokenized:  ['各', '种', 'app', '按', '钮', '的', '排', '序', '也', '是', '混', '乱', '的', '，'] \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# lang = \"en\"\n",
    "for lang in LANGS:\n",
    "    sample_sentence = random.choice(data[lang][\"train\"])[0]\n",
    "\n",
    "    # Print the original sentence.\n",
    "    print(' Original: ', sample_sentence)\n",
    "\n",
    "    # Print the sentence split into tokens.\n",
    "    print('Tokenized: ', tokenizer.tokenize(sample_sentence),\"\\n\")\n",
    "\n",
    "#     # Print the sentence mapped to token ids.\n",
    "#     print('Token IDs: ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sample_sentence)))\n",
    "\n",
    "    # Print the encoded sentence (with [CLS], [SEP] appended).\n",
    "#     print('Encoded IDs: ', tokenizer.encode(sample_sentence))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tune mBERT w/ the SA data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-21T16:44:17.557066Z",
     "start_time": "2020-05-21T16:44:17.508566Z"
    }
   },
   "outputs": [],
   "source": [
    "import time\n",
    "from transformers import get_linear_schedule_with_warmup\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader, RandomSampler\n",
    "\n",
    "MAX_LEN = 20\n",
    "NUM_EPOCHS = 2\n",
    "BATCH_SIZE = 32\n",
    "def get_optimizer(model, total_steps):\n",
    "    optim = AdamW(model.parameters(), lr = 5e-5)\n",
    "    scheduler = get_linear_schedule_with_warmup(optim, \n",
    "                                            num_warmup_steps = 0,\n",
    "                                            num_training_steps = total_steps)\n",
    "    return optim, scheduler\n",
    "\n",
    "\n",
    "def get_tensordata(sentences, labels):\n",
    "    input_ids,attention_masks = [], []\n",
    "    for sentence, label in zip(sentences, labels):\n",
    "        encoded_dict = tokenizer.encode_plus(sentence, add_special_tokens=True, max_length=MAX_LEN, pad_to_max_length = True, return_attention_mask = True, return_tensors = 'pt')\n",
    "        input_ids.append(encoded_dict['input_ids'])\n",
    "        attention_masks.append(encoded_dict['attention_mask'])\n",
    "    input_ids = torch.cat(input_ids, dim=0)\n",
    "    attention_masks = torch.cat(attention_masks, dim=0)\n",
    "    labels = [int((l == \"pos\")) for l in labels]\n",
    "    labels = torch.tensor(labels)\n",
    "    dataset = TensorDataset(input_ids, attention_masks, labels)\n",
    "    dataloader = DataLoader(\n",
    "                dataset,  # The training samples.\n",
    "                sampler = RandomSampler(dataset), # Select batches randomly\n",
    "                batch_size = BATCH_SIZE # Trains with this batch size.\n",
    "            )\n",
    "    return dataloader\n",
    "\n",
    "\n",
    "def run_model(model, data:dict, lang_train:list, lang_test:str, bool_print=False) -> float:\n",
    "    \n",
    "    def train(epoch, bool_valid=True):\n",
    "        if bool_print:\n",
    "            print(f'\\n======== Epoch {epoch+1} / {NUM_EPOCHS} ========\\nTraining...')\n",
    "        total_train_loss = 0\n",
    "        t0 = time.time()\n",
    "\n",
    "        model.train()\n",
    "        for step, batch in enumerate(data_train):\n",
    "\n",
    "            # Progress update every 50 batches.\n",
    "            if bool_print and step % 40 == 0 and not step == 0:\n",
    "                # Calculate elapsed time in minutes.\n",
    "                elapsed = format_time(time.time() - t0)\n",
    "\n",
    "                # Report progress.\n",
    "                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(data_train), elapsed))\n",
    "            batch_input_ids = batch[0].to(device)\n",
    "            batch_att_mask = batch[1].to(device)\n",
    "            batch_labels = batch[2].to(device)\n",
    "            model.zero_grad()  \n",
    "            loss, logits = model(batch_input_ids, \n",
    "                                 token_type_ids=None, \n",
    "                                 attention_mask=batch_att_mask, \n",
    "                                 labels=batch_labels)\n",
    "            total_train_loss += loss.item()\n",
    "            loss.backward()\n",
    "            optim.step()\n",
    "            scheduler.step()\n",
    "        avg_train_loss = total_train_loss / len(data_train)\n",
    "        training_time = format_time(time.time() - t0)\n",
    "\n",
    "        if bool_print:\n",
    "            print(\"\\n  Average training loss: {0:.2f}\".format(avg_train_loss))\n",
    "            print(\"  Training epcoh took: {:}\".format(training_time))\n",
    "\n",
    "        if bool_valid and data_test != None:\n",
    "            t0 = time.time()\n",
    "            # Put the model in evaluation mode--the dropout layers behave differently\n",
    "            # during evaluation.\n",
    "            model.eval()\n",
    "            # Tracking variables \n",
    "            total_eval_accuracy = 0\n",
    "            total_eval_loss = 0\n",
    "            nb_eval_steps = 0\n",
    "\n",
    "            # Evaluate data for one epoch\n",
    "            for batch in data_test:\n",
    "\n",
    "                # Unpack this training batch from our dataloader. \n",
    "                # As we unpack the batch, we'll also copy each tensor to the GPU using \n",
    "                # the `to` method.\n",
    "                batch_input_ids = batch[0].to(device)\n",
    "                batch_att_mask = batch[1].to(device)\n",
    "                batch_labels = batch[2].to(device)\n",
    "\n",
    "                # Tell pytorch not to bother with constructing the compute graph during\n",
    "                # the forward pass, since this is only needed for backprop (training).\n",
    "                with torch.no_grad(): \n",
    "                    (loss, logits) = model(batch_input_ids, \n",
    "                                           token_type_ids=batch_att_mask, \n",
    "                                           attention_mask=None,\n",
    "                                           labels=batch_labels)\n",
    "\n",
    "                # Accumulate the validation loss.\n",
    "                total_eval_loss += loss.item()\n",
    "\n",
    "                # Move logits and labels to CPU\n",
    "                logits = logits.detach().cpu().numpy()\n",
    "                label_ids = batch_labels.to('cpu').numpy()\n",
    "\n",
    "                # Calculate the accuracy for this batch of test sentences, and\n",
    "                # accumulate it over all batches.\n",
    "                total_eval_accuracy += flat_accuracy(logits, label_ids)\n",
    "\n",
    "\n",
    "            # Report the final accuracy for this validation run.\n",
    "            avg_val_accuracy = total_eval_accuracy / len(data_test)\n",
    "\n",
    "            # Calculate the average loss over all of the batches.\n",
    "            avg_val_loss = total_eval_loss / len(data_test)\n",
    "\n",
    "            # Measure how long the validation run took.\n",
    "            validation_time = format_time(time.time() - t0)\n",
    "            if bool_print:\n",
    "                print(\"  Accuracy: {0:.2f}\".format(avg_val_accuracy))\n",
    "                print(\"  Validation Loss: {0:.2f}\".format(avg_val_loss))\n",
    "                print(\"  Validation took: {:}\".format(validation_time))\n",
    "        else:\n",
    "            avg_val_loss, avg_val_accuracy, validation_time = None, None, None\n",
    "        log = {\n",
    "            'epoch': epoch + 1,\n",
    "            'Training Loss': avg_train_loss,\n",
    "            'Valid Loss': avg_val_loss,\n",
    "            'Valid Acc': avg_val_accuracy,\n",
    "        }\n",
    "        return log\n",
    "\n",
    "    sentences_train, y_train = [], []\n",
    "    for lang in lang_train:\n",
    "        _sentences, _labels = zip(*data[lang][\"train\"])\n",
    "        sentences_train += _sentences\n",
    "        y_train += _labels\n",
    "    sentences_test, y_test = zip(*data[lang_test][\"test\"])\n",
    "    \n",
    "    data_train, data_test = get_tensordata(sentences_train, y_train), get_tensordata(sentences_test, y_test)\n",
    "    \n",
    "    total_steps = len(data_train) * NUM_EPOCHS\n",
    "    optim, scheduler = get_optimizer(model, total_steps)\n",
    "    \n",
    "    \n",
    "    training_logs = []\n",
    "    max_acc = 0\n",
    "    for epoch in range(NUM_EPOCHS):\n",
    "        log = train(epoch, bool_valid=True)\n",
    "        max_acc = max(max_acc, log['Valid Acc'])\n",
    "        training_logs.append(log)\n",
    "    return max_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-21T16:46:34.648668Z",
     "start_time": "2020-05-21T16:44:18.768774Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train: ar, test: ar, acc: 0.841\n",
      "train: en, test: en, acc: 0.848\n",
      "train: es, test: es, acc: 0.878\n",
      "train: ru, test: ru, acc: 0.811\n",
      "train: zh, test: zh, acc: 0.830\n"
     ]
    }
   ],
   "source": [
    "res_dict = {}\n",
    "for lang in LANGS:\n",
    "    max_acc = run_model(model, data, [lang], lang)\n",
    "    res_dict[lang] = max_acc\n",
    "    print(f\"train: {lang}, test: {lang}, acc: {max_acc:.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Zero-shot Cross-lingual Transfer Experiments + Low-resource setting\n",
    "\n",
    "Now we control the number of training samples and compare how well cross-lingual transfer works"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-13T09:19:47.362238Z",
     "start_time": "2020-05-13T09:19:47.322335Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>#train</th>\n",
       "      <th>#test</th>\n",
       "      <th>train-pos%</th>\n",
       "      <th>test-pos%</th>\n",
       "      <th>sample</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>ar</th>\n",
       "      <td>1333</td>\n",
       "      <td>1145</td>\n",
       "      <td>0.60</td>\n",
       "      <td>0.57</td>\n",
       "      <td>الغرفة رائحتها كريهة جدًا</td>\n",
       "      <td>neg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>en</th>\n",
       "      <td>1333</td>\n",
       "      <td>555</td>\n",
       "      <td>0.67</td>\n",
       "      <td>0.72</td>\n",
       "      <td>This place is a must visit!</td>\n",
       "      <td>pos</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>es</th>\n",
       "      <td>1333</td>\n",
       "      <td>650</td>\n",
       "      <td>0.72</td>\n",
       "      <td>0.68</td>\n",
       "      <td>Perfecto... Como siempre ...</td>\n",
       "      <td>pos</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ru</th>\n",
       "      <td>1333</td>\n",
       "      <td>865</td>\n",
       "      <td>0.76</td>\n",
       "      <td>0.68</td>\n",
       "      <td>Спасибо вам огромное.</td>\n",
       "      <td>pos</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>zh</th>\n",
       "      <td>1333</td>\n",
       "      <td>529</td>\n",
       "      <td>0.57</td>\n",
       "      <td>0.59</td>\n",
       "      <td>看电影也相当爽，</td>\n",
       "      <td>pos</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    #train  #test  train-pos%  test-pos%                        sample label\n",
       "ar    1333   1145        0.60       0.57     الغرفة رائحتها كريهة جدًا   neg\n",
       "en    1333    555        0.67       0.72   This place is a must visit!   pos\n",
       "es    1333    650        0.72       0.68  Perfecto... Como siempre ...   pos\n",
       "ru    1333    865        0.76       0.68         Спасибо вам огромное.   pos\n",
       "zh    1333    529        0.57       0.59                      看电影也相当爽，   pos"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "data_sample = subset_data(data)\n",
    "print_data_stats(data_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2020-05-13T17:28:22.103Z"
    }
   },
   "outputs": [],
   "source": [
    "res_subset_dict = {}\n",
    "for lang_test in LANGS:\n",
    "    res_subset_dict[lang_test] = {}\n",
    "    for lang_train in LANGS:   \n",
    "        max_acc = run_model(model, data_sample, [lang_train], lang_test)\n",
    "        res_subset_dict[lang_test][lang_train] = max_acc\n",
    "        print(f\"train: {lang_train}, test: {lang_test}, acc: {max_acc:.3f}\")\n",
    "    print(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-13T08:33:54.093279Z",
     "start_time": "2020-05-13T08:33:54.074662Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ar</th>\n",
       "      <th>en</th>\n",
       "      <th>es</th>\n",
       "      <th>ru</th>\n",
       "      <th>zh</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>ar</th>\n",
       "      <td>0.863</td>\n",
       "      <td>0.835</td>\n",
       "      <td>0.882</td>\n",
       "      <td>0.824</td>\n",
       "      <td>0.787</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>en</th>\n",
       "      <td>0.861</td>\n",
       "      <td>0.849</td>\n",
       "      <td>0.882</td>\n",
       "      <td>0.827</td>\n",
       "      <td>0.808</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>es</th>\n",
       "      <td>0.861</td>\n",
       "      <td>0.861</td>\n",
       "      <td>0.875</td>\n",
       "      <td>0.833</td>\n",
       "      <td>0.800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ru</th>\n",
       "      <td>0.836</td>\n",
       "      <td>0.849</td>\n",
       "      <td>0.855</td>\n",
       "      <td>0.829</td>\n",
       "      <td>0.802</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>zh</th>\n",
       "      <td>0.843</td>\n",
       "      <td>0.839</td>\n",
       "      <td>0.848</td>\n",
       "      <td>0.827</td>\n",
       "      <td>0.800</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       ar     en     es     ru     zh\n",
       "ar  0.863  0.835  0.882  0.824  0.787\n",
       "en  0.861  0.849  0.882  0.827  0.808\n",
       "es  0.861  0.861  0.875  0.833  0.800\n",
       "ru  0.836  0.849  0.855  0.829  0.802\n",
       "zh  0.843  0.839  0.848  0.827  0.800"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display(pd.DataFrame.from_dict(res_dict))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-13T09:21:52.007990Z",
     "start_time": "2020-05-13T09:21:52.004481Z"
    }
   },
   "source": [
    "# Activity: Explain your observation from the table"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- observation 1:\n",
    "- observation 2:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cross-lingual Transfer (use all data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-13T17:01:09.415054Z",
     "start_time": "2020-05-13T16:50:14.607726Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train: ['ar', 'en', 'es', 'ru', 'zh'], test: ar, acc: 0.852\n",
      "train: ['ar', 'en', 'es', 'ru', 'zh'], test: en, acc: 0.856\n",
      "train: ['ar', 'en', 'es', 'ru', 'zh'], test: es, acc: 0.854\n",
      "train: ['ar', 'en', 'es', 'ru', 'zh'], test: ru, acc: 0.807\n",
      "train: ['ar', 'en', 'es', 'ru', 'zh'], test: zh, acc: 0.793\n"
     ]
    }
   ],
   "source": [
    "res_subset_dict = {}\n",
    "\n",
    "for lang_test in LANGS:   \n",
    "    max_acc = run_model(model, data, LANGS, lang_test)\n",
    "    res_subset_dict[lang_test] = max_acc\n",
    "    print(f\"train: {LANGS}, test: {lang_test}, acc: {max_acc:.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "train: ar, test: ar, acc: 0.841\n",
    "train: en, test: en, acc: 0.848\n",
    "train: es, test: es, acc: 0.878\n",
    "train: ru, test: ru, acc: 0.811\n",
    "train: zh, test: zh, acc: 0.830"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:etc]",
   "language": "python",
   "name": "conda-env-etc-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "nbTranslate": {
   "displayLangs": [
    "*"
   ],
   "hotkey": "alt-t",
   "langInMainMenu": true,
   "sourceLang": "en",
   "targetLang": "fr",
   "useGoogleTranslate": true
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
